package com.gitee.fastmybatis.core;

import com.gitee.fastmybatis.core.ext.MapperRunner;
import com.gitee.fastmybatis.core.ext.info.EntityInfo;
import com.gitee.fastmybatis.core.ext.spi.MapperBuilder;
import com.gitee.fastmybatis.core.ext.spi.SpiContext;
import com.gitee.fastmybatis.core.mapper.Mapper;
import com.gitee.fastmybatis.core.support.Dialect;
import com.gitee.fastmybatis.core.update.ModifyAttrsRecordProxyFactory;
import com.gitee.fastmybatis.core.update.Reflectors;
import com.gitee.fastmybatis.core.update.UpdateWrapper;
import com.gitee.fastmybatis.core.util.ClassUtil;
import com.gitee.fastmybatis.core.util.ConvertUtil;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.reflection.Reflector;

import java.lang.reflect.Field;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Static context shared by fastmybatis.
 *
 * <p>Holds the parsed {@link EntityInfo} per entity class, the per-thread list of extra columns
 * used by logical deletes, and the application context handed to the SPI {@link MapperBuilder}.
 * Also offers reflection helpers for reading and writing an entity's primary key.
 *
 * @author thc
 */
public class FastmybatisContext {

    private static final Log LOG = LogFactory.getLog(FastmybatisContext.class);

    /**
     * Entity metadata keyed by entity class.
     * NOTE(review): plain HashMap, so it is only safe if every setEntityInfo call completes before
     * concurrent reads start (e.g. during startup). Confirm against callers.
     */
    private static final Map<Class<?>, EntityInfo> ENTITY_INFO_MAP = new HashMap<>(16);

    /** Extra columns to update during a logical delete, bound to the current thread. */
    private static final ThreadLocal<List<EqualColumn>> DELETE_SET_BLOCK_LOCAL = new ThreadLocal<>();

    /** Opaque application context passed through to the MapperBuilder. */
    private static Object applicationContext;

    /**
     * Sets, for the current thread, the extra columns to update when performing a logical delete.
     * Passing {@code null} clears the current thread's binding.
     *
     * @param equalColumns columns to update; {@code null} clears the binding
     * @see com.gitee.fastmybatis.core.mapper.DeleteMapper#deleteByIds(Collection, EqualColumn...) deleteByIds
     * @see com.gitee.fastmybatis.core.mapper.DeleteMapper#deleteByColumn(String, Object, EqualColumn...) deleteByColumn
     */
    public static void setDeleteSetBlock(List<EqualColumn> equalColumns) {
        if (equalColumns == null) {
            // remove() instead of set(null) so no stale ThreadLocal entry lingers on pooled threads;
            // get() returns null either way.
            DELETE_SET_BLOCK_LOCAL.remove();
        } else {
            DELETE_SET_BLOCK_LOCAL.set(equalColumns);
        }
    }

    /**
     * Returns the extra logical-delete columns bound to the current thread.
     *
     * @return the bound columns, or {@code null} if none are set
     */
    public static List<EqualColumn> getDeleteSetBlock() {
        return DELETE_SET_BLOCK_LOCAL.get();
    }

    /**
     * Stores the application context. The value is ignored if a non-null context is already set.
     *
     * @param ctx application context object, passed as-is to the MapperBuilder
     */
    public static void setApplicationContext(Object ctx) {
        if (applicationContext == null) {
            applicationContext = ctx;
        }
    }

    /**
     * Returns a MapperRunner for the mapper registered for the given entity class.
     *
     * @param entityClass entity class
     * @param <T>         mapper type
     * @return the mapper runner
     */
    @SuppressWarnings("unchecked")
    public static <T extends Mapper> MapperRunner<T> getCrudMapperRunner(Class<?> entityClass) {
        // Unchecked: the caller asserts which mapper type is registered for this entity class.
        Class<T> mapperClass = (Class<T>) MybatisContext.getMapperClassByEntityClass(entityClass);
        return getMapperRunner(mapperClass);
    }

    /**
     * Returns a MapperRunner for the given mapper class, built by the SPI MapperBuilder.
     *
     * @param mapperClass mapper class
     * @param <T>         mapper type
     * @return the mapper runner
     */
    public static <T extends Mapper> MapperRunner<T> getMapperRunner(Class<T> mapperClass) {
        MapperBuilder mapperBuilder = SpiContext.getMapperBuilder();
        return mapperBuilder.getMapperRunner(mapperClass, applicationContext);
    }

    /**
     * Looks up the registered metadata for an entity class.
     *
     * @param entityClass entity class
     * @return the entity info, or {@code null} if the class is not registered
     */
    public static EntityInfo getEntityInfo(Class<?> entityClass) {
        return ENTITY_INFO_MAP.get(entityClass);
    }

    /**
     * Registers (or replaces) the metadata for an entity class.
     *
     * @param entityClass entity class
     * @param entityInfo  its metadata
     */
    public static void setEntityInfo(Class<?> entityClass, EntityInfo entityInfo) {
        ENTITY_INFO_MAP.put(entityClass, entityInfo);
    }

    /**
     * Returns the Dialect for an entity class.
     *
     * @param entityClass entity class
     * @return the Dialect, or {@link Dialect#UNKNOWN} if the class is not registered
     */
    public static Dialect getDialect(Class<?> entityClass) {
        EntityInfo entityInfo = getEntityInfo(entityClass);
        if (entityInfo == null) {
            return Dialect.UNKNOWN;
        }
        String dialect = entityInfo.getDialect();
        return Dialect.of(dialect);
    }

    /**
     * Returns the database primary-key column name for an entity class.
     *
     * @param entityClass entity class
     * @return the primary-key column name, or {@code null} if the class is not registered
     */
    public static String getPkColumnName(Class<?> entityClass) {
        EntityInfo entityInfo = getEntityInfo(entityClass);
        if (entityInfo == null) {
            return null;
        }
        return entityInfo.getPkColumnName();
    }

    /**
     * Returns the database primary-key column name for the entity managed by a mapper.
     *
     * @param mapperClass mapper class
     * @return the primary-key column name, or {@code null} if it cannot be resolved
     */
    public static String getPkColumnNameFromMapper(Class<? extends Mapper> mapperClass) {
        Class<?> entityClass = MybatisContext.getEntityClassByMapperClass(mapperClass);
        return getPkColumnName(entityClass);
    }

    /**
     * Returns the Java field name of an entity's primary key.
     *
     * @param entityClass entity class
     * @return the primary-key field name, or {@code null} if the class is not registered
     */
    public static String getPkJavaName(Class<?> entityClass) {
        EntityInfo entityInfo = getEntityInfo(entityClass);
        if (entityInfo == null) {
            return null;
        }
        return entityInfo.getPkJavaName();
    }

    /**
     * Reads the primary-key value of an entity.
     *
     * <p>{@link UpdateWrapper} instances are read through the getter of their source class. Other
     * entities are read directly from the field.
     *
     * @param entity entity object, may be {@code null}
     * @return the primary-key value, or {@code null} if the entity is null, unregistered, or the
     *         value cannot be read
     */
    public static Object getPkValue(Object entity) {
        if (entity == null) {
            return null;
        }
        if (entity instanceof UpdateWrapper) {
            return getPkValueByGetter(entity);
        }

        Field field = findPkField(entity.getClass());
        if (field == null) {
            return null;
        }
        ClassUtil.makeAccessible(field);
        try {
            return field.get(entity);
        } catch (IllegalAccessException e) {
            LOG.error("反射出错", e);
            return null;
        }
    }

    /**
     * Reads the primary-key value via the property getter.
     *
     * <p>For UpdateWrapper proxies, the source entity class is resolved first.
     *
     * @param entity non-null entity object
     * @return the primary-key value, or {@code null} if it cannot be read
     */
    private static Object getPkValueByGetter(Object entity) {
        Class<?> entityClass = entity.getClass();
        if (entity instanceof UpdateWrapper) {
            entityClass = ModifyAttrsRecordProxyFactory.getSrcClass(entity.getClass());
        }

        Field field = findPkField(entityClass);
        if (field == null) {
            return null;
        }

        Reflector reflector = Reflectors.of(entityClass);
        try {
            return reflector.getGetInvoker(field.getName()).invoke(entity, null);
        } catch (Exception e) {
            // Best effort: report the failure but still yield null to the caller.
            LOG.error("Failed to read primary key via getter, entityClass=" + entityClass.getName(), e);
            return null;
        }
    }

    /**
     * Sets the primary-key value of an entity, using the entity's runtime class to locate the key.
     *
     * @param entity entity object; a {@code null} entity is ignored
     * @param id     the id value; it is converted to the primary-key field type
     * @param <T>    entity type
     */
    public static <T> void setPkValue(T entity, Object id) {
        if (entity == null) {
            return;
        }
        setPkValue(entity, entity.getClass(), id);
    }

    /**
     * Sets the primary-key value of an entity through the property setter of {@code entityClass}.
     *
     * <p>The id is converted to the key field's type first. Failures are logged, not thrown.
     *
     * @param entity      entity object
     * @param entityClass class whose registered metadata describes the primary key
     * @param id          the id value
     */
    public static void setPkValue(Object entity, Class<?> entityClass, Object id) {
        Field field = findPkField(entityClass);
        if (field == null) {
            return;
        }

        Reflector reflector = Reflectors.of(entityClass);
        try {
            Object value = ConvertUtil.convert(id, field.getType());
            reflector.getSetInvoker(field.getName()).invoke(entity, new Object[]{value});
        } catch (Exception e) {
            // Best effort: do not propagate, but make the failure visible.
            LOG.error("Failed to set primary key value, entityClass=" + entityClass.getName(), e);
        }
    }

    /**
     * Locates the primary-key field of an entity class.
     *
     * @param entityClass entity class, may be {@code null}
     * @return the field, or {@code null} if the class is unregistered, has no primary-key name,
     *         or the field cannot be found
     */
    private static Field findPkField(Class<?> entityClass) {
        String pkJavaName = getPkJavaName(entityClass);
        if (pkJavaName == null) {
            // Never hand a null name to findField; its null-name behavior is not something we rely on.
            return null;
        }
        return ClassUtil.findField(entityClass, pkJavaName);
    }

    /**
     * Returns the names of the fields declared directly on {@code srcClass} that are not in
     * {@code includeProperties}.
     *
     * <p>Superclass fields are not considered. Static fields are included because
     * {@link Class#getDeclaredFields()} returns them.
     *
     * @param srcClass          class to inspect
     * @param includeProperties field names to keep; {@code null} is treated as empty
     * @return a new mutable set of field names to ignore
     */
    public static Set<String> getIgnoreProperties(Class<?> srcClass, Set<String> includeProperties) {
        if (null == includeProperties) {
            includeProperties = Collections.emptySet();
        }

        Set<String> ignoreProperties = new HashSet<>();
        for (Field field : srcClass.getDeclaredFields()) {
            if (!includeProperties.contains(field.getName())) {
                ignoreProperties.add(field.getName());
            }
        }
        return ignoreProperties;
    }
}
